#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ----------------------------------------------------------------------------
# This program is free software, you can redistribute it and/or modify it.
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ----------------------------------------------------------------------------

import sys
import os
import numpy as np
import torch
import tensorflow as tf

# NumPy-compatible bfloat16 scalar type taken from TensorFlow; it is the
# target dtype for the 'bf16' test cases when generated data is cast.
bf16 = tf.bfloat16.as_numpy_dtype
# Seed NumPy's global RNG so the randomly generated inputs are reproducible.
np.random.seed(0)
# Shared 'group_idx' fixtures. They are stored as immutable tuples and copied
# into a fresh list for every case so that no two cases share a list object.
#
# When a case has no 'group_idx_type' key (type 0), 'group_idx' holds the
# cumulative end row of each group; with 'group_idx_type': 1 it holds the
# number of rows in each group.

# Cumulative group ends for the [8875, 1228] cases (16 groups).
_GROUP_ENDS_16 = (
    737, 1492, 1968, 2422, 2834, 3220, 3663, 4411,
    4917, 5571, 6200, 6729, 7324, 7894, 8434, 8875,
)

# Cumulative group ends for the [1968, 2560] cases: 1, 4, 7, ..., 658
# (220 groups in steps of 3) followed by four larger groups.
_GROUP_ENDS_224 = tuple(range(1, 659, 3)) + (737, 1492, 1592, 1968)

# Cumulative group ends for the [200, 2560] cases; repeated values describe
# empty groups.
_GROUP_ENDS_DENSE_200 = (
    0, 1, 2, 4, 7, 11, 13, 14, 15, 16, 16, 18, 19, 19, 19, 21, 23, 24, 25, 25, 27, 29, 30, 31, 31, 32, 32, 32,
    32, 33, 34, 34, 34, 35, 38, 38, 40, 42, 42, 43, 45, 47, 48, 48, 51, 52, 52, 52, 55, 55, 56, 57, 58, 59, 60,
    64, 65, 66, 66, 69, 70, 70, 73, 73, 73, 74, 76, 79, 80, 81, 81, 81, 83, 83, 84, 84, 85, 86, 87, 87, 89, 94,
    94, 95, 95, 96, 97, 97, 99, 101, 102, 102, 103, 103, 104, 105, 106, 106, 106, 107, 108, 109, 110, 110, 112,
    112, 113, 113, 113, 114, 114, 115, 116, 118, 118, 121, 121, 123, 126, 129, 129, 129, 132, 134, 135, 136, 136,
    136, 137, 139, 139, 140, 141, 141, 142, 142, 143, 143, 143, 144, 145, 145, 146, 146, 148, 149, 150, 152, 152,
    153, 154, 155, 159, 160, 160, 161, 161, 161, 163, 164, 164, 164, 165, 165, 165, 166, 168, 168, 172, 174, 174,
    174, 174, 175, 177, 177, 178, 179, 180, 182, 182, 183, 183, 185, 185, 187, 189, 189, 190, 190, 192, 192, 194,
    194, 194, 197, 198, 198, 199, 200,
)

# Per-group row counts (group_idx_type 1) for the [200, 2560] cases; the
# counts add up to the 200 rows of grad_y.
_GROUP_SIZES_200 = (
    1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1,
    1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1,
    0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1,
    1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1,
    0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 2, 1, 1, 0, 0,
    0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1,
    0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1,
    1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0,
    0, 0, 2, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1,
    0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0,
    1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0,
    0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 92,
)

# Test cases, selected by position from the command line. Keys:
#   'dtype'          - grad_y data type: 'fp16', 'fp32' or 'bf16'
#   'grp_dtype'      - NumPy integer type used when writing group_idx.bin
#   'grad_y_shape'   - shape of the randomly generated grad_y
#   'group_idx'      - group description, or 'none' for the ungrouped case
#   'group_idx_type' - optional; 1 means 'group_idx' holds per-group sizes
case_list = [
    {'dtype': 'fp16', 'grp_dtype': np.int32, 'grad_y_shape': [40, 10], 'group_idx': [0, 3, 7, 10]},
    {'dtype': 'fp32', 'grp_dtype': np.int32, 'grad_y_shape': [16, 2000, 128], 'group_idx': 'none'},
    {'dtype': 'bf16', 'grp_dtype': np.int32, 'grad_y_shape': [100, 256], 'group_idx': [40, 60, 100]},
    {'dtype': 'fp16', 'grp_dtype': np.int32, 'grad_y_shape': [1968, 458], 'group_idx': [737, 1492, 1968]},
    {'dtype': 'fp32', 'grp_dtype': np.int32, 'grad_y_shape': [3200, 399], 'group_idx': [1190, 1490, 2000, 3200]},
    {'dtype': 'fp16', 'grp_dtype': np.int32, 'grad_y_shape': [8875, 1228], 'group_idx': list(_GROUP_ENDS_16)},
    {'dtype': 'fp16', 'grp_dtype': np.int32, 'grad_y_shape': [100, 256], 'group_idx': [40, 40, 100]},
    {'dtype': 'fp32', 'grp_dtype': np.int32, 'grad_y_shape': [1968, 2560], 'group_idx': list(_GROUP_ENDS_224)},
    {'dtype': 'bf16', 'grp_dtype': np.int64, 'grad_y_shape': [100, 256], 'group_idx': [40, 60, 100]},
    {'dtype': 'fp16', 'grp_dtype': np.int64, 'grad_y_shape': [1968, 458], 'group_idx': [737, 1492, 1968]},
    {'dtype': 'fp32', 'grp_dtype': np.int64, 'grad_y_shape': [3200, 399], 'group_idx': [1190, 1490, 2000, 3200]},
    {'dtype': 'fp16', 'grp_dtype': np.int64, 'grad_y_shape': [8875, 1228], 'group_idx': list(_GROUP_ENDS_16)},
    {'dtype': 'fp16', 'grp_dtype': np.int64, 'grad_y_shape': [100, 256], 'group_idx': [40, 40, 100]},
    {'dtype': 'fp32', 'grp_dtype': np.int64, 'grad_y_shape': [1968, 2560], 'group_idx': list(_GROUP_ENDS_224)},
    {'dtype': 'fp16', 'grp_dtype': np.int64, 'grad_y_shape': [1968, 2560], 'group_idx': list(_GROUP_ENDS_224)},
    {'dtype': 'bf16', 'grp_dtype': np.int64, 'grad_y_shape': [1968, 2560], 'group_idx': list(_GROUP_ENDS_224)},
    {'dtype': 'bf16', 'grp_dtype': np.int64, 'grad_y_shape': [200, 2560], 'group_idx': list(_GROUP_ENDS_DENSE_200)},
    {'dtype': 'fp16', 'grp_dtype': np.int64, 'grad_y_shape': [200, 2560], 'group_idx': list(_GROUP_ENDS_DENSE_200)},
    {'dtype': 'fp32', 'grp_dtype': np.int64, 'grad_y_shape': [200, 2560], 'group_idx': list(_GROUP_ENDS_DENSE_200)},
    {'dtype': 'fp16', 'grp_dtype': np.int64, 'grad_y_shape': [1, 2560], 'group_idx': [1]},
    {'dtype': 'fp32', 'grp_dtype': np.int64, 'grad_y_shape': [1, 2560], 'group_idx': [1]},
    {'dtype': 'bf16', 'grp_dtype': np.int64, 'grad_y_shape': [100, 256], 'group_idx': [40, 40, 20], 'group_idx_type': 1},
    {'dtype': 'fp16', 'grp_dtype': np.int32, 'grad_y_shape': [1968, 458], 'group_idx': [737, 737, 494], 'group_idx_type': 1},
    {'dtype': 'fp32', 'grp_dtype': np.int64, 'grad_y_shape': [200, 2560], 'group_idx': list(_GROUP_SIZES_200),
     'group_idx_type': 1},
    {'dtype': 'fp32', 'grp_dtype': np.int32, 'grad_y_shape': [200, 2560], 'group_idx': list(_GROUP_SIZES_200),
     'group_idx_type': 1},
]

def grouped_bias_add_grad(grad_y, group_idx, group_idx_type):
    """Compute the reference bias gradient by summing rows of each group.

    Args:
        grad_y: torch tensor holding the output gradient. For the grouped
            paths it is indexed as ``grad_y[start:end, :]`` / split along
            dim 0, so it must have at least two dimensions.
        group_idx: ``None`` for the ungrouped case, otherwise a sequence of
            integers (Python list, NumPy array or tensor) describing the
            groups along dim 0.
        group_idx_type: ``0`` if ``group_idx`` holds the cumulative end row
            of every group, any other value if it holds the number of rows
            in every group.

    Returns:
        With ``group_idx is None``: ``grad_y`` summed over dim 1.
        Otherwise a tensor with one row per group, each row being the sum
        over dim 0 of that group's rows (zeros for an empty group). An empty
        ``group_idx`` yields a tensor with zero rows.
    """
    if group_idx is None:
        return torch.sum(grad_y, 1)

    # Normalise to plain ints so lists, arrays and tensors are all accepted
    # (the previous code called .tolist(), which fails on a Python list).
    bounds = [int(v) for v in group_idx]

    if group_idx_type == 0:
        # Cumulative ends: group i covers rows [end[i-1], end[i]), first from 0.
        starts = [0] + bounds[:-1]
        segments = [grad_y[lo:hi, :] for lo, hi in zip(starts, bounds)]
    else:
        # Per-group row counts: let torch carve grad_y into consecutive chunks.
        segments = torch.split(grad_y, bounds, 0)

    partial_sums = [torch.sum(seg, 0, keepdim=True) for seg in segments]
    if not partial_sums:
        # torch.cat rejects an empty list; return an explicit zero-row result.
        return grad_y.new_zeros((0,) + tuple(grad_y.shape[1:]))
    # Concatenate once instead of growing the result inside the loop.
    return torch.cat(partial_sums, 0)

def gen_data_and_golden(case):
    """Generate the input files and golden output file for one test case.

    Writes into the current working directory:
        ./grad_y.bin     random normal data of shape ``case['grad_y_shape']``
                         in the case's dtype.
        group_idx.bin    the group description cast to ``case['grp_dtype']``
                         (skipped when the case has no groups).
        out_golden.bin   reference result of ``grouped_bias_add_grad``,
                         computed in float32 and cast back to the case dtype.

    Args:
        case: dict with keys 'dtype' ('fp32', 'fp16' or 'bf16', case
            insensitive), 'grp_dtype', 'grad_y_shape', and optionally
            'group_idx' (defaults to 'none') and 'group_idx_type'
            (defaults to 0).

    Raises:
        ValueError: if ``case['dtype']`` is not one of the supported names.
    """
    np_type_dict = {
        "fp32": np.float32,
        "fp16": np.float16,
        "bf16": bf16,
    }
    grad_y_dtype = case["dtype"].lower()
    if grad_y_dtype not in np_type_dict:
        # A plain .get() would return None and astype(None) would silently
        # produce float64 files, so fail loudly instead.
        raise ValueError(
            f"unsupported dtype {case['dtype']!r}; expected one of {sorted(np_type_dict)}"
        )
    np_type = np_type_dict[grad_y_dtype]
    grad_y_shape = case["grad_y_shape"]
    group_idx = case.get("group_idx", "none")
    group_idx_dtype = case["grp_dtype"]
    group_idx_type = case.get("group_idx_type", 0)

    grad_y = np.random.randn(*grad_y_shape).astype(np_type)
    grad_y.tofile("./grad_y.bin")

    # 'none' is the sentinel for the ungrouped case; check the type first so
    # the comparison never runs element-wise on an array-like value.
    if isinstance(group_idx, str) and group_idx == "none":
        group_idx = None
    else:
        np.array(group_idx).astype(group_idx_dtype).tofile("group_idx.bin")

    # Compute the reference in float32 regardless of the storage dtype.
    if grad_y_dtype != "fp32":
        grad_y = grad_y.astype(np.float32)
    grad_bias = grouped_bias_add_grad(torch.from_numpy(grad_y), group_idx, group_idx_type)
    grad_bias.numpy().astype(np_type).tofile("out_golden.bin")

if __name__ == "__main__":
    import glob

    # Validate the command line before touching any files, so a bad
    # invocation does not wipe the existing data.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <case index, 0..{len(case_list) - 1}>")
    try:
        case = case_list[int(sys.argv[1])]
    except (ValueError, IndexError):
        sys.exit(
            f"invalid case index {sys.argv[1]!r}; expected an integer in 0..{len(case_list) - 1}"
        )

    # Clean up .bin files left over from a previous run (portable replacement
    # for shelling out to `rm`; only regular files are removed).
    for stale_bin in glob.glob("*.bin"):
        if os.path.isfile(stale_bin):
            os.remove(stale_bin)

    gen_data_and_golden(case)

